import json
import pathlib
from sklearn.metrics import auc, precision_recall_curve, roc_curve
import pdb
import math
import numpy as np
from scipy.stats import iqr
import time 
import os

# Module-level dict; nothing in the visible script reads or writes it.
video_auc = dict()

def gaussian_kernel_original(x, mu, sigma):
    """Evaluate the unnormalised Gaussian exp(-0.5 * ((x - mu) / sigma) ** 2).

    Works element-wise when ``x`` is a NumPy array; the peak value is 1 at
    ``x == mu``.
    """
    standardized = (x - mu) / sigma
    return np.exp(-0.5 * standardized ** 2)

def gaussian_smoothing(data, sigma):
    """Weight ``data`` by a Gaussian window centred on its middle element.

    Despite the name this is not a convolution: element ``i`` is multiplied
    by ``exp(-0.5 * ((i - c) / sigma) ** 2)`` with ``c = len(data) // 2``,
    so values near the centre are kept and values near the ends shrink.

    Args:
        data: 1-D sequence or array of values.
        sigma: width of the window, in elements.

    Returns:
        NumPy array of the weighted values, same length as ``data``.

    A ``sigma`` of 0 is treated as the limit of a vanishingly narrow window
    (only the centre element survives) instead of evaluating 0/0 = NaN.
    """
    n = len(data)
    x = np.arange(n)
    centroid_index = n // 2

    if sigma == 0:
        kernel_values = (x == centroid_index).astype(float)
    else:
        kernel_values = gaussian_kernel_original(x, centroid_index, sigma)

    return kernel_values * data


def gaussian_kernel(size, sigma):
    """Return a normalised 1-D Gaussian kernel with ``size`` taps.

    The taps are ``size`` evenly spaced points on
    ``[-(size // 2), size // 2]``, so the kernel is symmetric about its
    middle, and the weights sum to 1.

    Args:
        size: number of taps.
        sigma: standard deviation, in tap units.
    """
    half = size // 2
    # Parenthesised on purpose: ``-size//2`` parses as ``(-size)//2``, which
    # for odd sizes is one lower than ``-(size//2)`` and left the kernel
    # asymmetric / off-centre (e.g. size 5 spanned [-3, 2]).
    taps = np.linspace(-half, half, size)
    kernel = np.exp(-taps ** 2 / (2 * sigma ** 2))
    return kernel / kernel.sum()

def gaussian_smooth_1d(data, size=5, sigma=1.0):
    """Smooth a 1-D signal with a normalised Gaussian kernel.

    Args:
        data: 1-D sequence of values.
        size: number of kernel taps.
        sigma: kernel standard deviation, in samples.

    Returns:
        NumPy array with exactly ``len(data)`` samples, aligned with ``data``
        (zero padding beyond both ends).

    ``np.convolve(..., mode='same')`` returns ``max(len(data), size)``
    samples centred on the *shorter* input, so for signals shorter than the
    kernel its leading values came from the attenuated convolution edge.
    Slicing the full convolution keeps the output aligned with ``data``; for
    ``len(data) >= size`` the result equals ``mode='same'``.
    """
    data = np.asarray(data)
    if data.size == 0:
        return np.zeros(0)

    kernel = gaussian_kernel(size, sigma)
    full = np.convolve(data, kernel, mode='full')
    start = (len(kernel) - 1) // 2
    return full[start:start + len(data)]

def softmax(row_vector):
    """Numerically stable softmax over all elements of ``row_vector``.

    The maximum is subtracted before exponentiating so large inputs do not
    overflow; the returned weights sum to 1.
    """
    shifted = row_vector - np.max(row_vector)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

def apply_window(binary_list, window_size=8):
    """Dilate every 1 in ``binary_list`` to its neighbours.

    For each position ``p`` that equals 1 in the input, positions
    ``p - window_size + 1`` through ``p + window_size - 1`` (clipped to the
    bounds) are set to 1. Other values are left unchanged.

    Returns:
        A new Python list.
    """
    values = np.array(binary_list)
    length = len(values)
    reach_back = window_size - 1
    for pos in np.nonzero(values == 1)[0]:
        lo = max(0, pos - reach_back)
        hi = min(length, pos + window_size)
        values[lo:hi] = 1
    return values.tolist()

def apply_window_rule(lst, window_size=4):
    """Spread each 1 in ``lst`` to positions within ``window_size // 2``.

    Position ``i`` becomes 1 when ``lst[i - h : i + h + 1]`` (with
    ``h = window_size // 2``, clipped to the bounds) contains a 1; every
    other position keeps its original value. Membership is always tested
    against the input, never the partially updated output.

    Returns:
        ``lst[:]`` with the updates applied (for a list input the original
        list is left untouched).
    """
    half = window_size // 2
    length = len(lst)
    dilated = lst[:]
    for pos in range(length):
        lo = max(0, pos - half)
        hi = min(length, pos + half + 1)
        if 1 not in lst[lo:hi]:
            continue
        dilated[pos] = 1
    return dilated


def find_non_uniform_neighbors(all_score, window_size):
    """Return the indices whose neighbourhood holds more than one value.

    The sequence is padded at each end by repeating its first/last value
    ``int((window_size - 1) / 2)`` times, and index ``i`` is reported when
    ``padded[i:i + window_size]`` is not constant. For odd ``window_size``
    the window is centred on ``i``; for even sizes it is right-biased by one.

    Args:
        all_score: any finite sequence of comparable values (list, tuple,
            NumPy array, ...).
        window_size: neighbourhood width.

    Returns:
        Sorted list of indices into ``all_score``; ``[]`` for empty input.
    """
    # Copy into a list so the ``+`` below concatenates for any sequence type
    # (a NumPy array would otherwise broadcast-add instead).
    scores = list(all_score)
    if not scores:
        return []

    pad = int((window_size - 1) / 2)
    padded_score = [scores[0]] * pad + scores + [scores[-1]] * pad

    result = []
    for i in range(len(scores)):
        neighbors = padded_score[i:i + window_size]
        if not all(x == neighbors[0] for x in neighbors):
            result.append(i)
    return result


def get_softmax_and_entropy(logits):
    """Return softmax probabilities and their Shannon entropy (in nats).

    Args:
        logits: torch tensor of unnormalised scores; the softmax is taken
            over the last dimension.

    Returns:
        ``(softmax_scores, entropy)`` where ``softmax_scores`` has the shape
        of ``logits`` and ``entropy`` has the last dimension reduced away.
    """
    # The file's module-level imports do not include torch, so the original
    # body raised NameError on ``F``/``torch``. Importing locally fixes that
    # and keeps the rest of this numpy-only script usable without torch.
    import torch
    import torch.nn.functional as F

    softmax_scores = F.softmax(logits, dim=-1)
    # log_softmax is used instead of log(softmax) for numerical stability.
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -torch.sum(softmax_scores * log_probs, dim=-1)
    return softmax_scores, entropy

# ---------------------------------------------------------------------------
# Configuration. Paths are placeholders: point them at real output folders.
# Every prediction folder holds one ``<video_key>.json`` file per video.
# ---------------------------------------------------------------------------
score_folder_internvl2 = "path/to/base/model/predictions/"
score_folder_multiple_agent = 'path/to/base/supervisor/predictions/'
# Per-video JSON whose keys are the segment start frames and whose entries
# provide 'end' (and 'score' when BINARY_SCORE_FLAG is set). The evaluation
# loop reads ``score_folder``, but it was never defined, so the script
# stopped with NameError on the first video.
# NOTE(review): possibly the same folder as one of the two above — confirm.
score_folder = 'path/to/segment/predictions/'

ann_root = 'data/UCF_Eval.json'
# Per-video ``<video_key>.npy`` arrays.
vision_folder = 'ucf_crime/vision_features/'
emd_folder = 'ucf_crime/sinkhorn_loss_no_cls/'

# True: use the raw per-segment 'score' from ``score_folder`` directly.
# False: run the neighbour-fusion + smoothing pipeline.
BINARY_SCORE_FLAG = False

# Per-video prediction dicts, keyed by video key.
result = {}
result_multiple_agent = {}
result_internvl2 = {}

# Not read or written anywhere else in the visible script.
result2 = {}
result3 = {}
auc_results = []
avg_l2 = []
name = []

# Frame-level predictions and labels, concatenated across all videos.
all_predict_score = []
all_gt = []

data_path = pathlib.Path(ann_root)
with data_path.open(encoding='utf-8') as f:
    annotation = json.load(f)

# Frames covered by every segment except the last one of a video.
sampling_rate = 16


def _load_video_scores(folder, video_key):
    """Load the per-video prediction JSON ``<folder><video_key>.json``."""
    with open(folder + video_key + '.json', 'r') as f:
        return json.load(f)


def _expand_to_frames(segment_scores, start_index, segments, sampling_rate):
    """Repeat each per-segment score once for every frame it covers.

    All segments but the last contribute ``sampling_rate`` frames; the last
    contributes ``int(segments[start]['end'] - int(start))`` frames.
    """
    frame_scores = []
    last = len(start_index) - 1
    for i, (start, seg_score) in enumerate(zip(start_index, segment_scores)):
        if i != last:
            num_ele = sampling_rate
        else:
            num_ele = int(segments[start]['end'] - int(start))
        frame_scores.extend([seg_score] * num_ele)
    return frame_scores


def _frame_ground_truth(sample):
    """Build the per-frame 0/1 ground truth from one annotation entry.

    ``temporal_label`` is read as flat ``[start, end, start, end, ...]``
    pairs; a start of -1 marks an unused pair and ends are clipped to
    ``sample['length']``.
    """
    length = sample['length']
    labels = sample['temporal_label']
    gt = [0.0] * length
    for anno_i in range(0, len(labels), 2):
        if labels[anno_i] != -1:
            anno_s = labels[anno_i]
            anno_e = min(labels[anno_i + 1], length)
            gt[anno_s:anno_e] = [1.0] * (anno_e - anno_s)
    return gt


def _knn_neighbours(features):
    """Top-k cosine-similarity neighbours for every row of a 2-D array.

    Returns ``(indices, similarities)``, each ``(num_rows, top_n)`` and in
    ascending order of similarity within a row. ``top_n`` is 15% of the rows
    (truncated); when that is 0 (fewer than 7 rows) every row is used, which
    is also what slicing with ``[:, -0:]`` does — stated explicitly here.
    """
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    normalized = features / norms
    similarity = np.dot(normalized, normalized.T)
    num_segment = similarity.shape[0]
    top_n = int(0.15 * num_segment) or num_segment
    indices = np.argsort(similarity, axis=1)[:, -top_n:]
    return indices, np.take_along_axis(similarity, indices, axis=1)


def _refine_segment_scores(neighbour_idx, neighbour_sim, base_scores,
                           supervisor_scores):
    """Fuse base-model and supervisor segment scores over visual neighbours.

    1. Where the base scores are not constant inside a 5-segment window
       (``find_non_uniform_neighbors``), take the supervisor score instead.
    2. Per segment, compute the entropy (bits) of those corrected base
       scores over the segment's visual neighbours.
    3. Segments with entropy >= the lower median average their neighbours'
       supervisor scores; the rest average their neighbours' corrected base
       scores. Both use softmax(similarity) weights.
    """
    base = list(base_scores)
    for pos in find_non_uniform_neighbors(base, 5):
        base[pos] = supervisor_scores[pos]

    entropy_list = []
    for i in range(len(base)):
        neighbour_score = np.array([base[n] for n in neighbour_idx[i].tolist()])
        _, counts = np.unique(neighbour_score, return_counts=True)
        if len(counts) == 0:
            entropy_list.append(0)
        else:
            probabilities = counts / len(neighbour_score)
            entropy_list.append(-np.sum(probabilities * np.log2(probabilities)))

    # Lower median of the entropies (index floor((n - 1) / 2) when sorted).
    median_entropy = sorted(entropy_list)[int(0.5 * (len(entropy_list) - 1))]
    # A set keeps the membership test below O(1) instead of O(n) per segment.
    high_entropy = {i for i, e in enumerate(entropy_list) if e >= median_entropy}

    fused = []
    for i in range(len(base)):
        source = supervisor_scores if i in high_entropy else base
        neighbour_score = np.array([source[n] for n in neighbour_idx[i].tolist()])
        fused.append(np.dot(softmax(neighbour_sim[i]), neighbour_score))
    return fused


for sample in annotation:
    key = sample['video'].split('/')[-1].split('.')[0]
    print(key)

    result[key] = _load_video_scores(score_folder, key)
    result_internvl2[key] = _load_video_scores(score_folder_internvl2, key)
    result_multiple_agent[key] = _load_video_scores(score_folder_multiple_agent, key)

    # Segment start frames: the JSON object's keys, in file order. The same
    # keys index all three prediction dicts.
    start_index = list(result[key])

    if BINARY_SCORE_FLAG:
        pred_score = _expand_to_frames(
            [result[key][start]['score'] for start in start_index],
            start_index, result[key], sampling_rate)
    else:
        neighbour_idx, neighbour_sim = _knn_neighbours(
            np.load(vision_folder + key + '.npy'))
        segment_scores = _refine_segment_scores(
            neighbour_idx,
            neighbour_sim,
            [result_internvl2[key][start]['score'] for start in start_index],
            [result_multiple_agent[key][start]['score'] for start in start_index],
        )

        smoothed = gaussian_smooth_1d(np.array(segment_scores), 20, 10)
        pred_score = _expand_to_frames(
            [round(smoothed[i], 1) for i in range(len(start_index))],
            start_index, result[key], sampling_rate)

        # Weight frames by a Gaussian centred on the middle frame with sigma
        # equal to half the frame count. Clamped to 1 so a 1-frame video does
        # not evaluate 0/0 (NaN), which would make roc_curve fail.
        sigma = max(1, int(len(pred_score) * 0.5))
        pred_score = gaussian_smoothing(np.array(pred_score), sigma)

    gt = _frame_ground_truth(sample)
    # Predictions and labels are concatenated across videos, so a per-video
    # length mismatch would silently misalign every later video.
    if len(pred_score) != len(gt):
        raise ValueError(
            f"{key}: {len(pred_score)} predicted frames but "
            f"{len(gt)} ground-truth frames")

    all_predict_score.extend(pred_score)
    all_gt.extend(gt)

# Pool every frame of every video into a single ROC curve; the printed value
# is the frame-level area under that curve.
fpr, tpr, threshold = roc_curve(y_true=all_gt, y_score=all_predict_score)
roc_auc = auc(x=fpr, y=tpr)
print(roc_auc)